
















# When this file is executed directly as a script, extend the import search
# path with the directory two levels up before the `quant` imports below run
# (presumably so the `quant` package resolves from a source checkout -- confirm).
# NOTE(review): the path is relative, so it depends on the current working directory.
if __name__ == '__main__':
    import sys
    sys.path.append('../../')


import time
import datetime
import requests
import json
import hmac
import base64
from quant import Order, config
from quant.markets import Functions
from quant.markets.functions import get_asset_value, get_contract_size
from quant.utils import logging, Iota, Timer2
from quant.const import *

from quant.exchanges import add_factory
from quant.exchanges.util import (Socket, RecentChecker2, update_balance_default, prepare_placing,prepare_canceling,
                                  parse_params, simulate_deal_spot, simulate_deal_usdt_swap_2,
                                  next_cid_head, avoid_sci_num)
from quant.exchanges.basics import ChannelBasic, HandlerBasic, ApiBasic, SpiBasic


class OkexChannel(ChannelBasic):
    """Public market-data websocket channel for OKEx (order book and trades)."""

    def init(self):
        """Create the websocket wrapper and start the recurring ping timer."""
        self.socket = Socket(ws_host, self.on_open, self.on_message, self.on_close)
        # Timer2 is defined elsewhere; presumably this fires _ping every 20 seconds.
        ping_timer = Timer2(self._ping, 20)
        ping_timer.start()

    def subscribe_book(self, ws):
        """Send a subscribe request for this channel's order book.

        High-frequency channels use the tick-by-tick ``books-l2-tbt`` feed,
        every other frequency uses the regular ``books`` feed.
        """
        inst_id = format_symbol(self.symbol)
        channel = "books-l2-tbt" if self.frequency == DataFrequency.High else "books"
        request = {"op": "subscribe", "args": [{"channel": channel, "instId": inst_id}]}
        ws.send(json.dumps(request))

    def subscribe_trade(self, ws):
        """Send a subscribe request for this channel's public trades."""
        inst_id = format_symbol(self.symbol)
        request = {"op": "subscribe", "args": [{"channel": "trades", "instId": inst_id}]}
        ws.send(json.dumps(request))

    def _ping(self):
        """Send the plain-text ``ping`` frame through the socket."""
        self.socket.send('ping')


class OkexHandler(HandlerBasic):
    """Turns raw OKEx public websocket frames into order-book and trade updates."""

    def init(self):
        """Create the checker used to drop frames that were already processed."""
        # RecentChecker2 is defined elsewhere; presumably it remembers the 50 most recent ids.
        self.data_id_checker = RecentChecker2(50)

    def process_book(self, routing_key, recv_time, raw):
        """Apply one order-book frame to ``self.book`` and publish the result.

        :param routing_key: key forwarded unchanged to ``process_close`` / ``push_book``.
        :param recv_time: local receive time, forwarded unchanged.
        :param raw: the text frame: ``'pong'``, the channel-close marker, or JSON.
        :return: whatever ``process_close`` / ``info_welcome`` return for control
            frames, otherwise ``None``.
        """
        if raw == 'pong':
            return

        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        data = json.loads(raw)
        if 'event' in data and data['event'] == 'subscribe':
            return self.info_welcome()

        # Frames without an action (for example an error event) carry no book
        # data; report them instead of failing with a bare KeyError below.
        if 'action' not in data:
            logging.error('Okex channel error: {}'.format(data))
            return

        action = data['action']
        data = data['data'][0]
        str_ts = data['ts']
        server_time = float(str_ts) / 1000  # 'ts' is divided by 1000 before being used as server time
        self.markets.info_engine.push_process_recv(str_ts, server_time, recv_time)

        # The timestamp string doubles as the frame id for de-duplication.
        if not self.data_id_checker.is_new(str_ts):
            return

        bids = data['bids']
        asks = data['asks']
        book = self.book

        if action == 'snapshot':
            # Each level is [price, size, ...]; only the first two fields are used.
            book['buy'] = [[float(p), float(v)] for p, v, *_ in bids]
            book['sell'] = [[float(p), float(v)] for p, v, *_ in asks]
        elif action == 'update':
            for p, v, *_ in bids:
                book['buy', float(p)] = float(v)
            for p, v, *_ in asks:
                book['sell', float(p)] = float(v)
        else:
            logging.error('OkexChannel.process_book(): action {}'.format(action))

        self.check_book(book)
        self.push_book(routing_key, recv_time, server_time, None, book)

    def process_trade(self, routing_key, recv_time, raw):
        """Publish the trades contained in one trade frame.

        :param routing_key: key forwarded unchanged to ``process_close`` / ``push_trade``.
        :param recv_time: local receive time, forwarded unchanged.
        :param raw: the text frame: ``'pong'``, the channel-close marker, or JSON.
        :return: whatever ``process_close`` / ``info_welcome`` return for control
            frames, otherwise ``None``.
        """
        if raw == 'pong':
            return

        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        # The frame is decoded exactly once and reused below.
        data = json.loads(raw)
        if 'event' in data and data['event'] == 'subscribe':
            return self.info_welcome()

        # Frames without a data list (for example an error event) carry no
        # trades; report them instead of failing with a bare KeyError below.
        if 'data' not in data:
            logging.error('Okex channel error: {}'.format(data))
            return

        data = data['data']
        # The last trade of the frame supplies the id and time of the whole batch.
        server_id = data[-1]['tradeId']
        server_time = int(data[-1]['ts']) / 1000
        self.markets.info_engine.push_process_recv(server_id, server_time, recv_time)

        if not self.data_id_checker.is_new(server_id):
            return

        trade_list = [[d['side'], float(d['px']), float(d['sz'])] for d in data]
        self.push_trade(routing_key, recv_time, server_time, server_id, trade_list)


class OkexApi(ApiBasic):
    """Authenticated OKEx v5 trading API.

    Account queries are sent as signed REST requests through ``self.requesting``.
    Order placement and cancellation are delegated to an executor object: the
    REST executor by default, or a websocket executor after ``set_ws_mode()``.
    """
    pass_phrase: str
    rest_executor = None
    ws_executor = None
    executor = None

    def init(self):
        """Read the pass phrase from ``self.keys`` and install the REST executor."""
        self.pass_phrase = self.keys['pass_phrase']
        self.rest_executor = RestExecutor(self)
        self.ws_executor = None
        self.executor = self.rest_executor

    def set_ws_mode(self, ws=True):
        """Route orders through the websocket executor (``ws=True``) or back to REST.

        The websocket executor is created on first use and reused afterwards.
        """
        if ws:
            if self.ws_executor is None:
                self.ws_executor = WsExecutor(self)
            self.executor = self.ws_executor
        else:
            self.executor = self.rest_executor

    def _signed_req(self, method, path, params=None, body=None):
        """Build a request for ``path`` signed with this account's credentials."""
        return build_req(method, rest_host, path, get_utc_time(), self.api_key, self.secret_key,
                         self.pass_phrase, params=params, body=body)

    def query_balance(self, symbol=None):
        """Asynchronously fetch balances; ``_on_balance`` filters them by ``symbol``."""
        req = self._signed_req('GET', '/api/v5/account/balance')
        self.requesting.request_async(symbol, req, self._on_balance)

    def query_all_balance(self):
        """Asynchronously fetch balances and replace the cached balance with them."""
        req = self._signed_req('GET', '/api/v5/account/balance')
        self.requesting.request_async(None, req, self._on_all_balance)

    def query_margin(self, symbol=None):
        """Asynchronously fetch the account balance and update the cached margin."""
        req = self._signed_req('GET', '/api/v5/account/balance')
        self.requesting.request_async(symbol, req, self._on_margin)

    def query_all_margin(self):
        """Same request as ``query_margin`` with the tag ``'all'`` instead of a symbol."""
        req = self._signed_req('GET', '/api/v5/account/balance')
        self.requesting.request_async('all', req, self._on_margin)

    def query_position(self, symbol=None):
        """Asynchronously fetch the position of each symbol yielded for ``symbol``."""
        for s in self._iter_default_symbol(symbol):
            req = self._signed_req('GET', '/api/v5/account/positions', params={'instId': format_symbol(s)})
            self.requesting.request_async(s, req, self._on_position)

    def query_all_position(self):
        """Asynchronously fetch every position and replace the cached positions."""
        req = self._signed_req('GET', '/api/v5/account/positions')
        self.requesting.request_async(None, req, self._on_all_position)

    def query_open_orders(self, symbol=None):
        """Asynchronously fetch pending orders of each symbol yielded for ``symbol``."""
        for s in self._iter_default_symbol(symbol):
            req = self._signed_req('GET', '/api/v5/trade/orders-pending', params={'instId': format_symbol(s)})
            # Tag the request with the symbol actually queried (not the possibly-None
            # argument) so a failure report names the right instrument.
            self.requesting.request_async(s, req, self._on_open_orders)

    def query_all_open_orders(self):
        """Asynchronously fetch the pending orders of all instruments."""
        req = self._signed_req('GET', '/api/v5/trade/orders-pending')
        self.requesting.request_async(None, req, self._on_all_open_orders)

    def place_order(self, order):  # offset = 'close' does not work
        """Place ``order`` through the active executor and return its result."""
        return self.executor.place_order(order)

    def cancel_order(self, order):
        """Cancel ``order`` through the active executor and return its result."""
        return self.executor.cancel_order(order)

    def cancel_all(self):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError('cancel_all() not implemented')

    def _on_balance(self, symbol, response):
        """Update the cached balance for the assets of ``symbol`` and publish it.

        With ``symbol=None`` the assets in ``self._asset_added`` are updated.
        Returns the balance mapping, or the result of ``_fail_request`` on a
        non-2xx response.
        """
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_balance', symbol, response)

        if symbol is None:
            interests = self._asset_added
        else:
            # NOTE(review): for a symbol like 'btc/usdt.swap' the second element is
            # 'usdt.swap', which never equals a currency name -- confirm this is intended.
            interests = {*symbol.split('/')}

        balance = self.account.balance
        j = response.json()
        data = j['data']
        if len(data) > 1:
            logging.critical('okex balance len > 1')
        data = data[0]['details']
        for dic in data:
            asset = dic['ccy'].lower()
            if asset in interests:
                free = float(dic['availEq'])
                total = float(dic['cashBal'])
                frozen = total - free
                update_balance_default(balance, asset, free, frozen)

        self._put_data(UserEvent.Balance, balance)
        return balance

    def _on_all_balance(self, _, response):
        """Replace the cached balance with every non-zero asset and publish it.

        Returns the freshly built balance mapping, or the result of
        ``_fail_request`` on a non-2xx response.
        """
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_balance', _, response)

        balance = {}
        j = response.json()

        data = j['data']
        if len(data) > 1:
            logging.critical('okex balance len > 1')

        data = data[0]['details']
        for dic in data:
            asset = dic['ccy'].lower()

            total = float(dic['cashBal'])
            if total == 0:
                # Assets with a zero cash balance are left out.
                continue

            free = float(dic['availEq'])
            frozen = total - free
            update_balance_default(balance, asset, free, frozen)

        # Mutate the account's mapping in place so existing references stay valid.
        self.account.balance.clear()
        self.account.balance.update(balance)

        self._put_data(UserEvent.Balance, balance)
        return balance

    def _on_margin(self, symbol, response):
        """Store the USDT cash balance as ``margin['usdt']`` and publish the margin."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_margin', symbol, response)

        margin = self.account.margin
        j = response.json()
        data = j['data']

        if len(data) > 1:
            logging.critical('okex margin len > 1')
        data = data[0]['details']
        for dic in data:
            asset = dic['ccy'].lower()
            if asset == 'usdt':
                margin['usdt'] = float(dic['cashBal'])

        self._put_data(UserEvent.Margin, margin)

    @staticmethod
    def _positions_from(payload):
        """Convert the ``data`` list of a positions response into per-symbol dicts."""
        positions = {}
        for dic in payload['data']:
            symbol = resume_symbol(dic['instId'])
            pos = float(dic['pos'])
            positions[symbol] = {
                'symbol': symbol,
                'position': pos,
                # A signed position is split into two non-negative sides; a flat
                # position reports 0 on both (never -0.0).
                'long': pos if pos > 0 else 0,
                'short': -pos if pos < 0 else 0,
            }
        return positions

    def _on_position(self, symbol, response):
        """Merge the returned positions into the cached ones and publish them.

        Returns the position mapping, or the result of ``_fail_request`` on a
        non-2xx response.
        """
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_position', symbol, response)

        position = self.account.position
        position.update(self._positions_from(response.json()))

        self._put_data(UserEvent.Position, position)
        return position

    def _on_all_position(self, param, response):
        """Replace the cached positions with the returned ones and publish them."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_position', param, response)

        position = self._positions_from(response.json())

        # Mutate the account's mapping in place so existing references stay valid.
        account_position = self.account.position
        account_position.clear()
        account_position.update(position)
        self._put_data(UserEvent.Position, position)

    def _orders_from(self, payload):
        """Convert the ``data`` list of a pending-orders response into ``Order`` objects."""
        orders = []
        for dic in payload['data']:
            cid = dic['clOrdId']
            if cid == '':
                # The order carries no client id; assign a fresh local one.
                cid = self._create_cid()

            orders.append(Order(
                symbol=resume_symbol(dic['instId']),
                side=dic['side'],
                price=float(dic['px']),
                # Size minus accumulated fill size: the amount still open.
                amount=float(dic['sz']) - float(dic['accFillSz']),
                client_id=cid,
                order_id=dic['ordId'],
                status=OrderStatus.Pending,
            ))
        return orders

    def _on_open_orders(self, symbol, response):
        """Hand the returned pending orders to the order processor and publish the orders."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_open_orders', symbol, response)

        open_orders = self._orders_from(response.json())
        self.account.order_processor.process_open_orders(open_orders)
        self._put_data(UserEvent.OpenOrders, self.account.orders)

    def _on_all_open_orders(self, _, response):
        """Same as ``_on_open_orders`` for the all-instruments query."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_open_orders', _, response)

        open_orders = self._orders_from(response.json())
        self.account.order_processor.process_open_orders(open_orders)
        self._put_data(UserEvent.OpenOrders, self.account.orders)

    def get_trading_config(self):
        """Synchronously fetch ``/api/v5/account/config`` and return the decoded JSON."""
        req = self._signed_req('GET', '/api/v5/account/config')
        result = self.requesting.request(req)
        return result.json()

    def set_net_pos(self, net=True):
        """Synchronously set the position mode and return the raw response text.

        :param net: ``True`` requests ``net_mode``, ``False`` requests ``long_short_mode``.
        """
        params = {
            'posMode': 'net_mode' if net else 'long_short_mode'
        }
        req = self._signed_req('POST', '/api/v5/account/set-position-mode', body=params)
        result = self.requesting.request(req)
        return result.text


class RestExecutor:
    """Order executor that places and cancels orders with signed REST requests."""

    def __init__(self, api):
        """Keep a reference to the owning API object (credentials, account, requesting)."""
        self.api = api

    def place_order(self, order):  # offset = 'close' does not work
        """Register ``order`` as being placed, send it asynchronously, return its client id."""
        owner = self.api
        order = prepare_placing(order, owner._create_cid())
        endpoint = '/api/v5/trade/order'

        # Symbols containing a '.' (derivatives in this module) trade in cross
        # margin mode, plain spot symbols in cash mode.
        if '.' in order.symbol:
            trade_mode = 'cross'
        else:
            trade_mode = 'cash'

        payload = {
            'instId': format_symbol(order.symbol),
            'side': order.side,
            'px': avoid_sci_num(order.price),
            'sz': order.amount,
            'tdMode': trade_mode,
            'clOrdId': order.client_id,
            'ordType': order_type_map[order.order_type],
        }

        # Market orders are sent without a price field.
        if order.order_type == OrderType.Market:
            payload.pop('px')

        request = build_req('POST', rest_host, endpoint, get_utc_time(),
                            owner.api_key, owner.secret_key, owner.pass_phrase, body=payload)
        owner.account.order_processor.process_place_operation(order)
        owner.requesting.request_async(order, request, self._on_place)
        return order.client_id

    def cancel_order(self, order):
        """Register ``order`` as being cancelled and send the cancel request asynchronously."""
        owner = self.api
        order = prepare_canceling(order)

        endpoint = '/api/v5/trade/cancel-order'
        payload = {'instId': format_symbol(order.symbol), 'ordId': order.order_id}

        request = build_req('POST', rest_host, endpoint, get_utc_time(),
                            owner.api_key, owner.secret_key, owner.pass_phrase, body=payload)
        owner.account.order_processor.process_cancel_operation(order)
        owner.requesting.request_async(order, request, self._on_cancel)

    def _on_place(self, order, response):
        """Translate the place-order response into an order status and publish it."""
        owner = self.api
        order = order.copy()

        accepted = False
        if int(response.status_code) // 100 == 2:
            reply = response.json()
            accepted = int(reply['code']) == 0

        if accepted:
            order.order_id = reply['data'][0]['ordId']
            order.status = OrderStatus.Pending
            # An accepted market order is reported as fully filled right away.
            if order.order_type == OrderType.Market:
                order.status = OrderStatus.FullyFilled
        else:
            # Either the HTTP status or the exchange result code signals failure.
            owner._fail_request('place_order', order, response)
            order.status = OrderStatus.PlaceFailed

        order = owner.account.order_processor.update_order(order)
        owner._put_data(UserEvent.Order, order)

    def _on_cancel(self, order, response):
        """Translate the cancel response into an order status and publish it."""
        owner = self.api
        order = order.copy()
        http_status = int(response.status_code)

        if http_status // 100 != 2:
            owner._fail_request('cancel_order', order, response)

        # Only status -1 marks the cancel as failed; every other status, including
        # other non-2xx ones, marks the order as canceled.
        order.status = OrderStatus.CancelFailed if http_status == -1 else OrderStatus.Canceled

        order = owner.account.order_processor.update_order(order)
        owner._put_data(UserEvent.Order, order)


class WsExecutor:
    """Order executor that places and cancels orders over the private websocket.

    Every outgoing request carries a unique message id. The request is kept in
    ``_placing`` / ``_canceling`` until the reply with the same id arrives or
    ``_kill_timeout`` declares it timed out.
    """

    def __init__(self, api):
        """Set up request bookkeeping, connect the socket and start the timeout timer."""
        self.api = api

        # Bookkeeping is created before the socket connects so that socket
        # callbacks never run against a half-initialised executor.
        self._msg_head = next_cid_head()
        self._msg_iota = Iota()
        self._placing = {}  # msg_id: (ts, order)
        self._canceling = {}  # msg_id: (ts, order)

        self.socket = Socket(ws_host_user, self._on_open, self._on_message, lambda x: None)
        self.socket.register_ping(10, build_ping)
        self.socket.connect()

        Timer2(self._kill_timeout, 1).start()

    def _next_msg_id(self):
        """Return a new message id: the fixed head followed by the next counter value."""
        return '{}{}'.format(self._msg_head, self._msg_iota.next())

    def place_order(self, order):  # offset = 'close' does not work
        """Register ``order`` as being placed, send it, and return its client id."""
        api = self.api
        msg_id = self._next_msg_id()
        order = prepare_placing(order, api._create_cid())
        self._placing[msg_id] = (time.time(), order)

        # Symbols containing a '.' trade in cross margin mode, others in cash mode.
        td_mode = 'cross' if '.' in order.symbol else 'cash'
        params = {
            'instId': format_symbol(order.symbol),
            'side': order.side,
            'px': avoid_sci_num(order.price),
            'sz': order.amount,
            'tdMode': td_mode,
            'clOrdId': order.client_id,
            'ordType': order_type_map[order.order_type],
        }
        # Market orders are sent without a price field.
        if order.order_type == OrderType.Market:
            del params['px']

        msg = json.dumps({'id': msg_id, 'op': 'order', 'args': [params]})

        api.account.order_processor.process_place_operation(order)
        self.socket.assert_send(msg)

        return order.client_id

    def cancel_order(self, order):
        """Register ``order`` as being cancelled and send the cancel request."""
        api = self.api
        msg_id = self._next_msg_id()
        order = prepare_canceling(order)
        self._canceling[msg_id] = (time.time(), order)

        params = {
            'instId': format_symbol(order.symbol),
            'ordId': order.order_id,
        }
        msg = json.dumps({'id': msg_id, 'op': 'cancel-order', 'args': [params]})

        api.account.order_processor.process_cancel_operation(order)
        self.socket.assert_send(msg)

    def _on_place(self, data):
        """Handle the reply to a place request; replies with unknown ids are ignored."""
        msg_id = data['id']
        if msg_id not in self._placing:
            return

        api = self.api
        order = self._placing.pop(msg_id)[1]
        order = order.copy()
        code = data['code']

        if int(code) == 0:
            data = data['data'][0]
            order.order_id = data['ordId']
            order.status = OrderStatus.Pending
            # An accepted market order is reported as fully filled right away.
            if order.order_type == OrderType.Market:
                order.status = OrderStatus.FullyFilled
        else:
            response = self._build_fail_response(200, str(data))
            api._fail_request('place_order', order, response)
            order.status = OrderStatus.PlaceFailed

        order = api.account.order_processor.update_order(order)
        api._put_data(UserEvent.Order, order)

    def _on_cancel(self, data):
        """Handle the reply to a cancel request; replies with unknown ids are ignored.

        The order is marked canceled whether or not the exchange reports an
        error; an error is additionally reported through ``_fail_request``.
        """
        msg_id = data['id']
        if msg_id not in self._canceling:
            return

        api = self.api
        order = self._canceling.pop(msg_id)[1]
        order = order.copy()
        code = data['code']

        if int(code) != 0:
            # Error replies may carry an empty "data" list; fall back to the
            # whole message rather than failing on data['data'][0].
            details = data.get('data') or []
            text = str(details[0]) if details else str(data)
            response = self._build_fail_response(200, text)
            api._fail_request('cancel_order', order, response)

        order.status = OrderStatus.Canceled
        order = api.account.order_processor.update_order(order)
        api._put_data(UserEvent.Order, order)

    def _on_open(self, ws):
        """Log in as soon as the socket is open."""
        self._login(ws)

    def _login(self, ws):
        """Send the signed login message built from the API's credentials."""
        api = self.api
        timestamp = time.time()
        # The signature is the base64 HMAC-SHA256 of "<timestamp>GET/users/self/verify".
        sign_message = f'{timestamp}GET/users/self/verify'
        mac = hmac.new(bytes(api.secret_key, encoding='utf8'), bytes(sign_message, encoding='utf-8'), digestmod='sha256')
        sign = base64.b64encode(mac.digest()).decode('utf-8')

        data = {
            "op": "login",
            "args":
                [
                    {
                        "apiKey": api.api_key,
                        "passphrase": api.pass_phrase,
                        "timestamp": timestamp,
                        "sign": sign
                    }
                ]
        }
        ws.send(json.dumps(data))

    def _on_message(self, ws, message):
        """Dispatch one incoming frame to the matching reply handler.

        :raises Exception: when the frame has no ``op`` field or an unknown one
            (login events and ``pong`` frames are silently accepted).
        """
        if message == 'pong':
            return

        data = json.loads(message)
        if 'event' in data:
            event = data['event']
            if event == 'login':
                return

        if 'op' not in data:
            err = '"op" not in message:{}'.format(data)
            raise Exception(err)

        op = data['op']
        if op == 'order':
            self._on_place(data)
        elif op == 'cancel-order':
            self._on_cancel(data)
        else:
            raise Exception('okex.ws_exc met unknown op: {}'.format(data['op']))

    def _build_fail_response(self, code, text):
        """Wrap ``text`` in a ``requests.Response`` with status ``code``.

        This lets websocket failures go through the same ``_fail_request`` path
        as REST failures.
        """
        try:
            content = '"{}"'.format(text)
            content = content.encode()
        except Exception as _:
            err = 'Requesting._build_fail_response() err, cannot encode {}'.format(text)
            logging.error(err)
            content = b'{}'

        resp = requests.Response()
        resp._content = content
        resp.status_code = code
        return resp

    def _expire(self, pending, operation, failed_status, too_late):
        """Fail every request in ``pending`` that was sent before ``too_late``."""
        api = self.api
        expired = [msg_id for msg_id, (ts, _) in pending.items() if ts < too_late]
        for msg_id in expired:
            _, order = pending.pop(msg_id)
            order = order.copy()

            # Status -1 is the synthetic code used here for a timed-out request.
            response = self._build_fail_response(-1, 'ws_executor timeout')
            api._fail_request(operation, order, response)

            order.status = failed_status
            order = api.account.order_processor.update_order(order)
            api._put_data(UserEvent.Order, order)

    def _kill_timeout(self):
        """Fail place/cancel requests that got no reply within ``config.rest_timeout``."""
        too_late = time.time() - config.rest_timeout
        self._expire(self._placing, 'place_order', OrderStatus.PlaceFailed, too_late)
        self._expire(self._canceling, 'cancel_order', OrderStatus.CancelFailed, too_late)


class OkexSpi(SpiBasic):
    """Private OKEx websocket listener for order, balance and position pushes."""
    pass_phrase: str

    def init(self):
        """Read the API pass phrase from ``self.keys``."""
        self.pass_phrase = self.keys['pass_phrase']

    def add_symbol(self, symbol):
        """Register ``symbol`` and, if already connected, (re)send the subscriptions."""
        self.connect_once()
        super().add_symbol(symbol)
        if self._socket.is_connected():
            # _on_login sends the subscriptions for everything registered so far.
            self._on_login(self._socket)

    def _create_socket(self):
        """Create the private websocket with a ping registered via ``build_ping``."""
        socket = Socket(ws_host_user, self._on_open, self._on_message, self._on_close)
        socket.register_ping(15, build_ping)
        return socket

    def _on_open(self, ws):
        """Log in as soon as the socket is open."""
        self._login(ws)

    def _on_message(self, ws, message):
        """Dispatch one incoming frame to the matching handler.

        :raises NotImplementedError: for a data push on an unsupported channel.
        """
        if message == 'pong':
            return

        data = json.loads(message)
        if 'event' in data:
            event = data['event']
            if event == 'login':
                self._on_login(ws)
                return
            if event == 'subscribe':
                return
            if event == 'error':
                # Error events carry no channel data; report them instead of
                # failing with a bare KeyError on data['arg'] below.
                logging.error('okex spi error event: {}'.format(data))
                return

        channel = data['arg']['channel']
        data = data['data']
        if channel == 'orders':
            self._handle_order(data)
        elif channel == 'account':
            self._handle_balance(data)
        elif channel == 'positions':
            self._handle_position(data)
        else:
            raise NotImplementedError('okex spi event {} not implemented'.format(channel))

    def _on_login(self, ws):
        """Subscribe to account, order and (for '.'-symbols) position channels."""
        for asset in self._asset_added:
            data = {"op": "subscribe", "args": [{"channel": "account", "ccy": asset.upper()}]}
            ws.send(json.dumps(data))

        for symbol in self._symbol_added:
            arg = {"channel": "orders", "instType": get_inst_type(symbol), "instId": format_symbol(symbol)}
            data = {"op": "subscribe", "args": [arg]}
            ws.send(json.dumps(data))

        for symbol in self._symbol_added:
            # Only symbols with a delivery suffix get a positions subscription.
            if '.' in symbol:
                args = {"channel": "positions", "instType": get_inst_type(symbol), "instId": format_symbol(symbol)}
                data = {"op": "subscribe", "args": [args]}
                ws.send(json.dumps(data))

    def _login(self, ws):
        """Send the signed login message built from this object's credentials."""
        timestamp = time.time()
        # The signature is the base64 HMAC-SHA256 of "<timestamp>GET/users/self/verify".
        sign_message = f'{timestamp}GET/users/self/verify'
        mac = hmac.new(bytes(self.secret_key, encoding='utf8'), bytes(sign_message, encoding='utf-8'), digestmod='sha256')
        sign = base64.b64encode(mac.digest()).decode('utf-8')

        data = {
            "op": "login",
            "args":
                [
                    {
                        "apiKey": self.api_key,
                        "passphrase": self.pass_phrase,
                        "timestamp": timestamp,
                        "sign": sign
                    }
                ]
        }
        ws.send(json.dumps(data))

    def _handle_order(self, data):
        """Convert pushed order records into ``Order`` updates and deal events."""
        for dic in data:
            cid = dic['clOrdId']
            if cid == '':
                # The order carries no client id; assign a fresh local one.
                cid = self._create_cid()

            # Price preference: order price, then fill price, then the price of
            # the locally known order; records with none of these are skipped.
            if dic['px'] != '':
                price = float(dic['px'])
            elif dic['fillPx'] != '':
                price = float(dic['fillPx'])
            elif cid in self.account.orders:
                price = self.account.orders[cid].price
            else:
                continue

            # Parse the fill size once; an empty string counts as no fill.
            raw_fill = dic['fillSz']
            filled = float(raw_fill) if raw_fill != '' else 0

            # NOTE(review): the remaining amount subtracts 'fillSz'; the REST
            # open-order parser in this file uses 'accFillSz' -- confirm which is intended.
            order = Order(
                symbol=resume_symbol(dic['instId']),
                side=dic['side'],
                price=price,
                amount=float(dic['sz']) - filled,
                client_id=cid,
                order_id=dic['ordId'],
                status=status_map[dic['state']]
            )
            order = self.account.order_processor.update_order(order)
            self._put_data(UserEvent.Order, order)

            # Reuse the parsed fill size so an empty 'fillSz' cannot raise here.
            if dic['state'] in deal_status and filled != 0:
                self._put_deal(order, filled)

    def _handle_position(self, data):
        """Merge pushed positions into the cached positions and publish them."""
        position = self.account.position
        for dic in data:
            symbol = resume_symbol(dic['instId'])
            pos = float(dic['pos'])
            position[symbol] = {
                'symbol': symbol,
                'position': pos,
                # A signed position is split into two non-negative sides; a flat
                # position reports 0 on both (never -0.0).
                'long': pos if pos > 0 else 0,
                'short': -pos if pos < 0 else 0,
            }

        self._put_data(UserEvent.Position, position)

    def _handle_balance(self, data):
        """Update cached balances of registered assets and the USDT margin, then publish both."""
        if len(data) > 1:
            logging.critical('okex balance len > 1')

        balance = self.account.balance
        data = data[0]['details']
        for dic in data:
            asset = dic['ccy'].lower()
            if asset in self._asset_added:
                free = float(dic['availEq'])
                total = float(dic['cashBal'])
                frozen = total - free
                update_balance_default(balance, asset, free, frozen)

        # The margin is the USDT free plus frozen amount from the cached balance.
        margin = self.account.margin
        usdt = balance.get('usdt', None)
        if usdt is not None:
            margin['usdt'] = usdt['free'] + usdt['frozen']

        self._put_data(UserEvent.Balance, balance)
        self._put_data(UserEvent.Margin, margin)


class OkexFunctions(Functions):
    """Public (unauthenticated) OKEx market helpers: tickers, precisions, contract sizes."""

    def get_all_ticker(self):
        """Return ``{symbol: {'price', 'vol', 'bid', 'ask'}}`` for SPOT and SWAP.

        Only instruments quoted in ``usd`` or ``usdt`` are included. A last
        price that cannot be parsed is logged and reported as ``0``.
        """
        path = '/api/v5/market/tickers'
        url = rest_host + path

        data = []
        for inst in ['SPOT', 'SWAP']:
            resp = self.session.get(url, params={'instType': inst}, timeout=5)
            data += resp.json()['data']

        result = {}
        valid = ['usd', 'usdt']

        for d in data:
            inst = d['instId']
            money = inst.split('-')[1].lower()
            if money not in valid:
                continue
            money_value = get_asset_value(money)

            s = resume_symbol(inst)
            raw_last = d['last']
            try:
                price = float(raw_last)
            except (TypeError, ValueError):
                # Report the raw value; converting it again here would re-raise
                # and the zero fallback would never be reached.
                logging.error('okex price: {!r}'.format(raw_last))
                price = 0

            # For symbols with a delivery suffix the volume is also scaled by price.
            if '.' in s:
                value_factor = price * money_value
            else:
                value_factor = money_value

            result[s] = {
                'price': price,
                'vol': float(d['volCcy24h']) * value_factor,
                'bid': float(d['bidPx']),
                'ask': float(d['askPx']),
                }
        return result

    def get_all_precision(self):
        """Return ``{symbol: (tickSz, lotSz)}`` for SPOT and SWAP symbols containing 'usd'."""
        mp = {}
        for inst in ['SPOT', 'SWAP']:
            mp.update(self._get_inst_attr(inst))

        result = {s: (d['tickSz'], d['lotSz']) for s, d in mp.items() if 'usd' in s}
        return result

    def get_value_factor(self, symbol):
        """Return contract size times the value of the symbol's base asset.

        Symbols without a delivery suffix use a contract size of 1.
        """
        if '.' not in symbol:
            contract_size = 1
        else:
            contract_size = get_contract_size(Exchange.Okex, symbol)

        asset_value = get_asset_value(symbol.split('/')[0])
        return contract_size * asset_value

    def get_all_contract_size(self):
        """Return ``{symbol: contract size}`` for SWAP instruments and ``/usdt`` tickers.

        ``usdc`` and unsupported quote currencies map to ``None`` (the latter
        are also logged); every ticker symbol ending in ``/usdt`` maps to 1.
        """
        inst_types = ['SWAP']
        attr = {}
        for typ in inst_types:
            attr.update(self._get_inst_attr(typ))

        result = {}
        for s, d in attr.items():
            uly, deliver = s.split('.')
            coin, money = uly.split('/')

            if money == 'usdt':
                par = float(d['par'])
            elif money == 'usd':
                # Contract value converted through the usd and coin asset values.
                par = float(d['par']) * get_asset_value('usd') / get_asset_value(coin)
            elif money == 'usdc':
                par = None
            else:
                par = None
                logging.error('Okex contract size for {} not implemented'.format(s))
            result[s] = par

        all_ticker = self.get_all_ticker()
        for s in all_ticker:
            if s.endswith('/usdt'):
                result[s] = 1

        return result

    def get_recent_trade(self, symbol, limit=500):
        """Return recent public trades as ``(trade_id, side, price, size, time)`` tuples.

        Returns an empty list (after printing the response) when the reply has
        no ``data`` field. The time is the ``ts`` field divided by 1000.
        """
        path = '/api/v5/market/trades'
        url = rest_host + path
        param = {
            'instId': format_symbol(symbol),
            'limit': str(limit),
        }
        response = self.session.get(url, params=param, timeout=5)
        js = response.json()
        if 'data' not in js:
            print('data not in response:')
            print(js)
            return []

        result = [(d['tradeId'], d['side'], float(d['px']), float(d['sz']), int(d['ts'])/1000) for d in js['data']]
        return result

    def _get_inst_attr(self, ins_type):
        """Return ``{symbol: {'tickSz', 'lotSz', 'par'}}`` for one instrument type.

        ``par`` is the instrument's ``ctVal`` field; all values are kept as
        returned by the exchange.
        """
        path = '/api/v5/public/instruments'
        url = rest_host + path
        param = {'instType': ins_type}
        response = self.session.get(url, params=param, timeout=5)
        js = response.json()

        result = {}
        for d in js['data']:
            s = resume_symbol(d['instId'])
            result[s] = {
                'tickSz': d['tickSz'],
                'lotSz': d['lotSz'],
                'par': d['ctVal'],
            }

        return result

    def simulate_deal(self, order, deal_amount, mid_price, account):
        """Simulate a fill for ``/usdt`` spot or ``/usdt.swap`` symbols.

        :raises Exception: for any other kind of symbol.
        """
        symbol = order.symbol

        if symbol.endswith('/usdt'):
            return simulate_deal_spot(order, deal_amount, mid_price, account)

        if order.symbol.endswith('/usdt.swap'):
            size_factor = get_contract_size('Okex', order.symbol)
            return simulate_deal_usdt_swap_2(order, deal_amount, mid_price, account, size_factor)

        err = 'Okex.simultate_deal() for {} not implemented'.format(symbol)
        raise Exception(err)


def format_symbol(symbol):
    """Convert an internal symbol into an exchange instrument id.

    ``'btc/usdt'`` becomes ``'BTC-USDT'``; a symbol with a delivery part
    after a dot (``'btc/usdt.swap'``) gets that part translated by
    ``format_deliver`` and appended with a dash (``'BTC-USDT-SWAP'``).
    """
    pair = symbol
    suffix = ''
    if '.' in symbol:
        # Exactly one dot is expected; more raise ValueError on unpacking.
        pair, raw_deliver = symbol.split('.')
        suffix = format_deliver(raw_deliver)

    contract = pair.upper().replace('/', '-')
    if not suffix:
        return contract
    return contract + '-{}'.format(suffix)


def resume_symbol(contract):
    """Convert an exchange instrument id back into an internal symbol.

    ``'BTC-USDT'`` becomes ``'btc/usdt'``.  When a third dash-separated
    part is present it is translated by ``resume_deliver`` and appended
    after a dot (``'BTC-USDT-SWAP'`` -> ``'btc/usdt.swap'``); any parts
    beyond the third are ignored.
    """
    base, quote, *extra = contract.split('-')
    symbol = '/'.join((base, quote)).lower()
    if extra:
        symbol = '{}.{}'.format(symbol, resume_deliver(extra[0]))
    return symbol


def format_deliver(deliver):
    """Translate an internal delivery tag into the exchange's suffix.

    Only ``'swap'`` is supported (returned as ``'SWAP'``).

    :raises NotImplementedError: for any other tag.
    """
    if deliver != 'swap':
        raise NotImplementedError('Okex deliver {} not implemented'.format(deliver))
    return 'SWAP'


def resume_deliver(date):
    """Translate an exchange delivery suffix into the internal tag.

    Only ``'SWAP'`` is supported (returned as ``'swap'``).

    :raises NotImplementedError: for any other suffix.
    """
    if date != 'SWAP':
        raise NotImplementedError('Okex deliver {} not implemented'.format(date))
    return 'swap'


def get_inst_type(symbol):
    """Return the exchange instrument type for an internal symbol.

    Symbols without a dot are ``'SPOT'``; symbols ending in ``.swap`` are
    ``'SWAP'``.

    :raises NotImplementedError: for any other dotted symbol.
    """
    has_deliver = '.' in symbol
    if not has_deliver:
        return 'SPOT'
    if not symbol.endswith('.swap'):
        raise NotImplementedError('okex inst_type for {} not implemented'.format(symbol))
    return 'SWAP'


def get_utc_time():
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS.mmmZ``.

    The clock is read directly in UTC, so the result is independent of the
    host's local timezone.  (Subtracting a fixed 8-hour offset from local
    time is only correct on a UTC+8 machine and yields a wrong timestamp
    everywhere else.)

    :return: ISO-8601 string with millisecond precision and a ``Z`` suffix.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    # Drop tzinfo so isoformat() does not append "+00:00"; the trailing
    # "Z" carries the UTC designator instead.
    t = now.replace(tzinfo=None).isoformat("T", "milliseconds")
    return t + "Z"


def build_req(method, url, path, utc_time, api_key, secret_key, pass_phs, params=None, body=None):
    """Build a signed ``requests.Request`` for a private REST call.

    The signature is the base64-encoded HMAC-SHA256 (keyed with
    ``secret_key``) of ``utc_time + method + path[?query] + body_json``.

    :param method: HTTP method string; it is part of the signed message.
    :param url: host prefix that is prepended to ``path``.
    :param path: request path; the query string built from ``params`` is
        appended to it before signing.
    :param utc_time: timestamp string, sent as ``OK-ACCESS-TIMESTAMP``.
    :param api_key: sent as ``OK-ACCESS-KEY``.
    :param secret_key: HMAC key (string).
    :param pass_phs: sent as ``OK-ACCESS-PASSPHRASE``.
    :param params: optional query parameters, serialised by ``parse_params``.
    :param body: optional JSON-serialisable body.
    :return: an unprepared ``requests.Request``.
    """
    if params is not None:
        path = path + '?' + parse_params(params)

    # Serialise once so the signed text and the transmitted body are the
    # exact same string.
    payload = None if body is None else json.dumps(body)

    prehash = '{}{}{}'.format(utc_time, method, path)
    if payload is not None:
        prehash = prehash + payload

    digest = hmac.new(
        bytes(secret_key, encoding='utf8'),
        bytes(prehash, encoding='utf-8'),
        digestmod='sha256',
    ).digest()

    headers = {
        'Content-Type':        'application/json',
        'OK-ACCESS-KEY':        api_key,
        'OK-ACCESS-SIGN':       base64.b64encode(digest),
        'OK-ACCESS-TIMESTAMP':  utc_time,
        'OK-ACCESS-PASSPHRASE': pass_phs,
    }
    return requests.Request(method, url+path, headers, data=payload)


def build_ping():
    """Return the literal ``'ping'`` heartbeat payload."""
    heartbeat = 'ping'
    return heartbeat


# Base URL prepended to every REST path in this module.
# NOTE(review): this uses the okex.com domain while the websocket hosts below
# use okx.com — confirm the REST domain is still the intended one.
rest_host = 'https://www.okex.com'
# Websocket endpoint for public channels (used by OkexChannel).
ws_host = 'wss://ws.okx.com:8443/ws/v5/public'
# Websocket endpoint for private channels.
ws_host_user = 'wss://ws.okx.com:8443/ws/v5/private'

# Internal OrderType -> order type string used by the exchange.
order_type_map = {
    OrderType.Limit: 'limit',
    OrderType.PostOnly: 'post_only',
    OrderType.Market: 'market',
    OrderType.IOC: 'ioc',
}
# Exchange order state string -> internal OrderStatus.
status_map = {
    'live':             OrderStatus.Pending,
    'canceled':         OrderStatus.Canceled,
    'partially_filled': OrderStatus.PartFilled,
    'filled':           OrderStatus.FullyFilled,
}
# Exchange order states in which at least part of the order has traded
# (presumably used to decide when to record a deal — verify against callers).
deal_status = {'filled', 'partially_filled'}


class NullCls:
    """Null-object stand-in.

    Accepts and ignores any constructor arguments, and resolves every
    attribute that is not otherwise defined to a callable that accepts
    anything and returns ``None``.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are deliberately discarded.
        return None

    def _func(self, *args, **kwargs):
        # Universal no-op handed out for every unknown attribute.
        return None

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails, so _func itself
        # and the dunder methods defined above are still found directly.
        return self._func


# Hand the Okex implementation classes to the exchange factory at import time
# (presumably so they can be looked up by Exchange.Okex — see add_factory).
add_factory(Exchange.Okex, OkexChannel, OkexHandler, OkexApi, OkexSpi, OkexFunctions)


"""
1. TODO: check that get_contract_size() is covered.
2. contract_size() is wrong for USD-margined perpetual swaps!

"""


# Manual test harness: only runs when this file is executed directly.
# It needs the project's `keys` module and talks to the live exchange.
if __name__ == '__main__':
    from quant.utils import set_test_mode
    from quant.markets.functions import set_quick_mode, get_value_factor
    import keys

    set_test_mode()
    set_quick_mode()

    def try_market():
        """Print every SPOT instrument's attributes as (symbol, attrs) pairs."""
        # a = get_value_factor(Exchange.Okex, 'btc/usdt.swap')
        # print(a)

        from quant.markets import Markets
        from utils import print_raw
        fn = OkexFunctions()
        b = fn._get_inst_attr('SPOT')
        for i in b.items():
            print(i)

    def try_api():
        """Create an OkexApi with a stored key, then block on input forever.

        The commented-out snippets below are ad-hoc experiments (placing,
        querying and cancelling orders) that are enabled by hand as needed.
        """
        api = OkexApi(keys.get_key('okex_wzz'))
        api.account.info_engine.show_user_data()
        api.account.info_engine.show_problems()

        config.rest_timeout = 3

        # buy_price = 1000
        # order = Order('eth/usdt.swap', 'buy', buy_price, 10)
        # api.place_order(order)

        # api.add_symbol('doge/usdt')

        # api.query_all_balance()
        # api.query_open_orders()
        # api.query_all_open_orders()
        # api.query_margin()
        # api.account.position = {'btc/usdt.swap': {'asdfjklasd;fjkl'}}
        # api.query_all_position()
        # api.query_position('doge/usdt.swap')

        # order = Order('luna/usdt.swap', 'buy', 8, 1)
        # api.place_order(order)

        # api.set_executor_mode()

        # while True:
        #     input(':')
        #     config.rest_timeout = 3
        #     cid = api.place_order(Order('doge/usdt.swap', 'buy', 0.1, 1, offset=Offset.Close))
        #     input(':')
        #     config.rest_timeout = 0.1
        #     if cid in api.account.orders:
        #         order = api.account.orders[cid]
        #         api.cancel_order(order)
        #     else:
        #         print('cid not exist')

        # while True:
        #     api.cancel_order(Order(symbol='doge/usdt.swap', client_id='a', order_id='360012137279954947'))
        #     input(':')

        # api.query_all_open_orders()
        # input(':')
        # for o in api.account.orders.values():
        #     api.cancel_order(o)

        # Keep the process alive until it is interrupted.
        while True:
            input(':')

    def try_spi():
        """Create an OkexSpi, add 'eth/usdt.swap', then re-add
        'doge/usdt.swap' each time the user presses Enter."""
        key = keys.get_key('okex_1_v5')
        spi = OkexSpi(key)

        spi.account.info_engine.show_user_data()

        spi.add_symbol('eth/usdt.swap')
        while True:
            input(':')
            spi.add_symbol('doge/usdt.swap')

        # NOTE(review): unreachable — the loop above never breaks.
        while True:
            input(':')

    # Default action: fetch all tickers, keep the symbols containing
    # 'usdt.swap' and sort them by their 'vol' field, largest first.
    # Nothing is printed unless one of the loops below is uncommented.
    fn = OkexFunctions()
    a = fn.get_all_ticker()
    a = {s: d for s, d in a.items() if 'usdt.swap' in s}
    a = [(s, d) for s, d in a.items()]
    a.sort(key=lambda x: x[1]['vol'], reverse=True)
    # for i in a:
    #     print(i)

    # n = 0
    # for i in a:
    #     print(n)
    #     n += 1
    #     print(i)









